import torch

from attacks import Attack
import torch.nn.functional as F

from constants import DEVICE
from utils import cross_entropy_loss, de_normalization, normalization


class PGD(Attack):
    """Projected Gradient Descent (PGD) attack under an L-infinity budget.

    Starts from a uniformly random point inside the eps-ball around the input,
    then takes ``steps`` signed-gradient ascent steps on the loss. After every
    step the iterate is projected back into the eps-ball intersected with the
    [0, 1] pixel range. Perturbation arithmetic happens in de-normalized pixel
    space; the model is always fed normalized images.
    """

    def __init__(self, model, eps=16 / 255, steps=10, num_classes=1000):
        """
        :param model: DNN model
        :param eps: the maximum L-infinity perturbation, in de-normalized pixel units
        :param steps: the number of iterations
        :param num_classes: number of classes used to one-hot encode the labels
                            (defaults to 1000, the previously hard-coded value)
        """
        super().__init__("PGD", model)
        self.eps = eps
        self.steps = steps
        self.num_classes = num_classes
        # Step size chosen so that ``steps`` steps together span exactly eps.
        self.alpha = self.eps / self.steps

    def forward(self, images, labels):
        """Craft adversarial examples for ``images``.

        :param images: input batch in normalized space (as consumed by ``self.model``)
        :param labels: ground-truth class indices for the batch
        :return: adversarial batch in normalized space, detached from the autograd graph
        """
        targets = F.one_hot(labels.type(torch.int64), self.num_classes).float().to(DEVICE)
        # The attack never needs gradients w.r.t. the clean images themselves.
        images_de_normalized = de_normalization(images).detach()
        # Per-pixel projection bounds: the eps-ball intersected with the valid [0, 1] range.
        images_min = torch.clamp(images_de_normalized - self.eps, min=0.0, max=1.0)
        images_max = torch.clamp(images_de_normalized + self.eps, min=0.0, max=1.0)

        # Random start drawn uniformly from [-eps, eps] per pixel, then projected.
        noise = torch.empty_like(images_de_normalized).uniform_(-self.eps, self.eps)
        adv_de_normalized = torch.clamp(images_de_normalized + noise, min=images_min, max=images_max)
        adv = normalization(adv_de_normalized).detach()

        for _ in range(self.steps):
            # adv is a fresh leaf here. autograd.grad needs it to require grad.
            # Because each iterate is built under no_grad, the graph never grows across steps.
            adv.requires_grad_(True)
            logits = self.model(adv)
            loss = cross_entropy_loss(logits, targets)
            grad = torch.autograd.grad(loss, adv)[0]

            with torch.no_grad():
                # Ascend the loss in pixel space, then project back into the feasible box.
                adv_de_normalized = de_normalization(adv)
                adv_de_normalized = torch.clamp(adv_de_normalized + self.alpha * torch.sign(grad), min=images_min,
                                                max=images_max)
                adv = normalization(adv_de_normalized)

        return adv.detach()
